{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3b05af3b",
   "metadata": {},
   "source": [
    "(tune-mnist-keras)=\n",
    "\n",
    "# Using Keras & TensorFlow with Tune\n",
    "\n",
    "<a id=\"try-anyscale-quickstart-tune_mnist_keras\" href=\"https://console.anyscale.com/register/ha?render_flow=ray&utm_source=ray_docs&utm_medium=docs&utm_campaign=tune_mnist_keras\">\n",
    "    <img src=\"../../_static/img/run-on-anyscale.svg\" alt=\"try-anyscale-quickstart\">\n",
    "</a>\n",
    "<br></br>\n",
    "\n",
    "```{image} /images/tf_keras_logo.jpeg\n",
    ":align: center\n",
    ":alt: Keras & TensorFlow Logo\n",
    ":height: 120px\n",
    ":target: https://keras.io\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "- `pip install \"ray[tune]\" tensorflow==2.18.0 filelock`\n",
    "\n",
    "## Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "19e3c389",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"tuneStatus\">\n",
       "  <div style=\"display: flex;flex-direction: row\">\n",
       "    <div style=\"display: flex;flex-direction: column;\">\n",
       "      <h3>Tune Status</h3>\n",
       "      <table>\n",
       "<tbody>\n",
       "<tr><td>Current time:</td><td>2025-02-13 15:22:41</td></tr>\n",
       "<tr><td>Running for: </td><td>00:00:41.76        </td></tr>\n",
       "<tr><td>Memory:      </td><td>21.4/36.0 GiB      </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    </div>\n",
       "    <div class=\"vDivider\"></div>\n",
       "    <div class=\"systemInfo\">\n",
       "      <h3>System Info</h3>\n",
       "      Using AsyncHyperBand: num_stopped=0<br>Bracket: Iter 320.000: None | Iter 80.000: None | Iter 20.000: None<br>Logical resource usage: 2.0/12 CPUs, 0/0 GPUs\n",
       "    </div>\n",
       "    \n",
       "  </div>\n",
       "  <div class=\"hDivider\"></div>\n",
       "  <div class=\"trialStatus\">\n",
       "    <h3>Trial Status</h3>\n",
       "    <table>\n",
       "<thead>\n",
       "<tr><th>Trial name             </th><th>status    </th><th>loc            </th><th style=\"text-align: right;\">  hidden</th><th style=\"text-align: right;\">  learning_rate</th><th style=\"text-align: right;\">  momentum</th><th style=\"text-align: right;\">  iter</th><th style=\"text-align: right;\">  total time (s)</th><th style=\"text-align: right;\">  accuracy</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>train_mnist_533a2_00000</td><td>TERMINATED</td><td>127.0.0.1:36365</td><td style=\"text-align: right;\">     371</td><td style=\"text-align: right;\">     0.0799367 </td><td style=\"text-align: right;\">  0.588387</td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         20.8515</td><td style=\"text-align: right;\">  0.984583</td></tr>\n",
       "<tr><td>train_mnist_533a2_00001</td><td>TERMINATED</td><td>127.0.0.1:36364</td><td style=\"text-align: right;\">     266</td><td style=\"text-align: right;\">     0.0457424 </td><td style=\"text-align: right;\">  0.22303 </td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         19.5277</td><td style=\"text-align: right;\">  0.96495 </td></tr>\n",
       "<tr><td>train_mnist_533a2_00002</td><td>TERMINATED</td><td>127.0.0.1:36368</td><td style=\"text-align: right;\">     157</td><td style=\"text-align: right;\">     0.0190286 </td><td style=\"text-align: right;\">  0.537132</td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         16.6606</td><td style=\"text-align: right;\">  0.95385 </td></tr>\n",
       "<tr><td>train_mnist_533a2_00003</td><td>TERMINATED</td><td>127.0.0.1:36363</td><td style=\"text-align: right;\">     451</td><td style=\"text-align: right;\">     0.0433488 </td><td style=\"text-align: right;\">  0.18925 </td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         22.0514</td><td style=\"text-align: right;\">  0.966283</td></tr>\n",
       "<tr><td>train_mnist_533a2_00004</td><td>TERMINATED</td><td>127.0.0.1:36367</td><td style=\"text-align: right;\">     276</td><td style=\"text-align: right;\">     0.0336728 </td><td style=\"text-align: right;\">  0.430171</td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         20.0884</td><td style=\"text-align: right;\">  0.964767</td></tr>\n",
       "<tr><td>train_mnist_533a2_00005</td><td>TERMINATED</td><td>127.0.0.1:36366</td><td style=\"text-align: right;\">     208</td><td style=\"text-align: right;\">     0.071015  </td><td style=\"text-align: right;\">  0.419166</td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         17.933 </td><td style=\"text-align: right;\">  0.976083</td></tr>\n",
       "<tr><td>train_mnist_533a2_00006</td><td>TERMINATED</td><td>127.0.0.1:36475</td><td style=\"text-align: right;\">     312</td><td style=\"text-align: right;\">     0.00692959</td><td style=\"text-align: right;\">  0.714595</td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         13.058 </td><td style=\"text-align: right;\">  0.944017</td></tr>\n",
       "<tr><td>train_mnist_533a2_00007</td><td>TERMINATED</td><td>127.0.0.1:36479</td><td style=\"text-align: right;\">     169</td><td style=\"text-align: right;\">     0.0694114 </td><td style=\"text-align: right;\">  0.664904</td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         10.7991</td><td style=\"text-align: right;\">  0.9803  </td></tr>\n",
       "<tr><td>train_mnist_533a2_00008</td><td>TERMINATED</td><td>127.0.0.1:36486</td><td style=\"text-align: right;\">     389</td><td style=\"text-align: right;\">     0.0370836 </td><td style=\"text-align: right;\">  0.665592</td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         14.018 </td><td style=\"text-align: right;\">  0.977833</td></tr>\n",
       "<tr><td>train_mnist_533a2_00009</td><td>TERMINATED</td><td>127.0.0.1:36487</td><td style=\"text-align: right;\">     389</td><td style=\"text-align: right;\">     0.0676138 </td><td style=\"text-align: right;\">  0.52372 </td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">         14.0043</td><td style=\"text-align: right;\">  0.981833</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "  </div>\n",
       "</div>\n",
       "<style>\n",
       ".tuneStatus {\n",
       "  color: var(--jp-ui-font-color1);\n",
       "}\n",
       ".tuneStatus .systemInfo {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus td {\n",
       "  white-space: nowrap;\n",
       "}\n",
       ".tuneStatus .trialStatus {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus h3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       ".tuneStatus .hDivider {\n",
       "  border-bottom-width: var(--jp-border-width);\n",
       "  border-bottom-color: var(--jp-border-color0);\n",
       "  border-bottom-style: solid;\n",
       "}\n",
       ".tuneStatus .vDivider {\n",
       "  border-left-width: var(--jp-border-width);\n",
       "  border-left-color: var(--jp-border-color0);\n",
       "  border-left-style: solid;\n",
       "  margin: 0.5em 1em 0.5em 1em;\n",
       "}\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-13 15:22:41,843\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/rdecal/ray_results/exp' in 0.0048s.\n",
      "2025-02-13 15:22:41,846\tINFO tune.py:1041 -- Total run time: 41.77 seconds (41.75 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best hyperparameters found were: {'threads': 2, 'learning_rate': 0.07993666231835218, 'momentum': 0.5883866709655042, 'hidden': 371} | Accuracy: 0.98458331823349\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "from filelock import FileLock\n",
    "from tensorflow.keras.datasets import mnist\n",
    "\n",
    "from ray import tune\n",
    "from ray.tune.schedulers import AsyncHyperBandScheduler\n",
    "from ray.air.integrations.keras import ReportCheckpointCallback\n",
    "\n",
    "\n",
    "def train_mnist(config):\n",
    "    # https://github.com/tensorflow/tensorflow/issues/32159\n",
    "    import tensorflow as tf\n",
    "\n",
    "    batch_size = 128\n",
    "    num_classes = 10\n",
    "    epochs = 12\n",
    "\n",
    "    # Serialize the MNIST download so concurrent trials don't race on the shared dataset cache.\n",
    "    with FileLock(os.path.expanduser(\"~/.data.lock\")):\n",
    "        (x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "    x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "    model = tf.keras.models.Sequential(\n",
    "        [\n",
    "            tf.keras.Input(shape=(28, 28)),\n",
    "            tf.keras.layers.Flatten(),\n",
    "            tf.keras.layers.Dense(config[\"hidden\"], activation=\"relu\"),\n",
    "            tf.keras.layers.Dropout(0.2),\n",
    "            tf.keras.layers.Dense(num_classes, activation=\"softmax\"),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    model.compile(\n",
    "        loss=\"sparse_categorical_crossentropy\",\n",
    "        optimizer=tf.keras.optimizers.SGD(\n",
    "            learning_rate=config[\"learning_rate\"], momentum=config[\"momentum\"]\n",
    "        ),\n",
    "        metrics=[\"accuracy\"],\n",
    "    )\n",
    "\n",
    "    model.fit(\n",
    "        x_train,\n",
    "        y_train,\n",
    "        batch_size=batch_size,\n",
    "        epochs=epochs,\n",
    "        verbose=0,\n",
    "        validation_data=(x_test, y_test),\n",
    "        # Report Keras's training accuracy to Tune (and save a checkpoint) after every epoch.\n",
    "        callbacks=[ReportCheckpointCallback(metrics={\"accuracy\": \"accuracy\"})],\n",
    "    )\n",
    "\n",
    "\n",
    "def tune_mnist():\n",
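    "    # ASHA makes stopping decisions on training_iteration, which advances once per reported epoch.\n",
    "    sched = AsyncHyperBandScheduler(\n",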
    "        time_attr=\"training_iteration\", max_t=400, grace_period=20\n",
    "    )\n",
    "\n",
    "    tuner = tune.Tuner(\n",
    "        tune.with_resources(train_mnist, resources={\"cpu\": 2, \"gpu\": 0}),\n",
    "        tune_config=tune.TuneConfig(\n",
    "            metric=\"accuracy\",\n",
    "            mode=\"max\",\n",
    "            scheduler=sched,\n",
    "            num_samples=10,\n",
    "        ),\n",
    "        run_config=tune.RunConfig(\n",
    "            name=\"exp\",\n",
    "            stop={\"accuracy\": 0.99},\n",
    "        ),\n",
    "        param_space={\n",
    "            \"threads\": 2,\n",
    "            \"learning_rate\": tune.uniform(0.001, 0.1),\n",
    "            \"momentum\": tune.uniform(0.1, 0.9),\n",
    "            \"hidden\": tune.randint(32, 512),\n",
    "        },\n",
    "    )\n",
    "    results = tuner.fit()\n",
    "    return results\n",
    "\n",
    "\n",
    "results = tune_mnist()\n",
    "best_result = results.get_best_result()\n",
    "print(f\"Best hyperparameters found were: {best_result.config} | Accuracy: {best_result.metrics['accuracy']}\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d7e46189",
   "metadata": {},
   "source": [
    "The last line of output should look something like this (exact values vary from run to run):\n",
    "\n",
    "```\n",
    "Best hyperparameters found were: {'threads': 2, 'learning_rate': 0.07607440973606909, 'momentum': 0.7715363277240616, 'hidden': 452} | Accuracy: 0.98458331823349\n",
    "```\n",
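    "\n",
    "`results.get_best_result()` returns the trial with the highest `accuracy`, because the `TuneConfig` sets `metric=\"accuracy\"` and `mode=\"max\"`. The following sketch mimics that selection in plain Python over a few rows copied from the trial status table printed by the run above. It is an illustration only, not Tune's implementation. It assumes each trial is ranked by its last reported value, which is the default `scope=\"last\"` of `ResultGrid.get_best_result`.\n",
    "\n",
    "```python\n",
    "# Illustrative re-implementation of best-trial selection, not Ray Tune's code.\n",
    "# Rows copied from the trial status table of the run above (a subset of the 10 trials).\n",
    "trials = [\n",
    "    {'trial': 'train_mnist_533a2_00000', 'hidden': 371, 'accuracy': 0.984583},\n",
    "    {'trial': 'train_mnist_533a2_00005', 'hidden': 208, 'accuracy': 0.976083},\n",
    "    {'trial': 'train_mnist_533a2_00007', 'hidden': 169, 'accuracy': 0.9803},\n",
    "    {'trial': 'train_mnist_533a2_00009', 'hidden': 389, 'accuracy': 0.981833},\n",
    "]\n",
    "\n",
    "metric, mode = 'accuracy', 'max'  # mirrors TuneConfig(metric=..., mode=...)\n",
    "pick = max if mode == 'max' else min\n",
    "best = pick(trials, key=lambda t: t[metric])\n",
    "print(best['trial'], best['hidden'], best[metric])  # train_mnist_533a2_00000 371 0.984583\n",
    "```\n",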
    "\n",
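    "The status output also reports `Using AsyncHyperBand: num_stopped=0`, with rungs at iterations 20, 80, and 320. The sketch below reproduces those rung positions. It assumes the scheduler's default `reduction_factor=4` and a rung count of `int(log(max_t / grace_period) / log(reduction_factor) + 1)`. This is a simplified mirror of Ray's internal bracket logic and is not a public API. Because `train_mnist` reports once per epoch and runs for only 12 epochs, no trial reaches the first rung at iteration 20. As a result, the scheduler never stops a trial early. Setting `grace_period` below the number of epochs would likely let ASHA stop underperforming trials.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def asha_rungs(max_t, grace_period, reduction_factor=4):\n",
    "    # Rung milestones grace_period * reduction_factor**k, largest first (as in the status line).\n",
    "    num_rungs = int(math.log(max_t / grace_period) / math.log(reduction_factor) + 1)\n",
    "    return [grace_period * reduction_factor**k for k in reversed(range(num_rungs))]\n",
    "\n",
    "\n",
    "rungs = asha_rungs(max_t=400, grace_period=20)\n",
    "print(rungs)  # [320, 80, 20]\n",
    "\n",
    "epochs = 12  # train_mnist reports once per epoch, so at most 12 training iterations\n",
    "reachable = [r for r in rungs if r <= epochs]\n",
    "print(reachable)  # [] -> no rung is ever reached, so no trial is stopped early\n",
    "```\n",
    "\n",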
    "## More Keras and TensorFlow Examples\n",
    "\n",
    "- {doc}`/tune/examples/includes/pbt_memnn_example`: Example of training a memory network (MemNN) on bAbI with Keras using Population Based Training (PBT).\n",
    "- {doc}`/tune/examples/includes/tf_mnist_example`: Converts the advanced TensorFlow 2.0 MNIST quickstart to use Tune\n",
    "  with the class-based `Trainable` API. This example uses `tf.function`.\n",
    "  Original code from TensorFlow: https://www.tensorflow.org/tutorials/quickstart/advanced\n",
    "- {doc}`/tune/examples/includes/pbt_tune_cifar10_with_keras`:\n",
    "  A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tune-keras",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
